import numpy as np
import torch
import random
from torch.nn import functional as F



def update_GPM(task_id, model, mat_list, threshold, feature_list=None, proj=None, every_task_base=None):
    """Update the Gradient Projection Memory (GPM) bases after training a task.

    For every layer ``i`` the matrix ``mat_list[i]`` is decomposed with an SVD.
    The leading left singular vectors are kept, where the number kept is how
    many have cumulative energy (squared singular values) below
    ``threshold[i]``.

    * First task (``feature_list`` is None or empty): those vectors become the
      layer's memory basis ``feature_list[i]`` (criterion Eq-5).
    * Later tasks: the activation is first projected off the stored basis.
      The leading directions of the residual are appended until the
      accumulated energy fraction reaches ``threshold[i]`` (criterion Eq-9).
      The result is truncated to at most as many columns as rows.

    Args:
        task_id: index of the task that was just trained.
        model: unused; kept for interface compatibility.
        mat_list: per-layer 2-D numpy matrices passed to ``np.linalg.svd``.
        threshold: per-layer energy fractions compared against the cumulative
            singular-value energy ratio.
        feature_list: per-layer bases from previous tasks, updated in place.
            ``None`` (the default) or an empty list means "first task".
        proj: optional nested store written as ``proj[task_id][layer]``.
            It receives the directions added for this task. When no direction
            is added, it receives a copy of ``proj[task_id - 1][layer]``.
        every_task_base: optional nested store written as
            ``every_task_base[task_id][layer]``. It receives this task's own
            threshold-truncated basis, independent of the accumulated memory.

    Returns:
        The updated ``feature_list``.
    """
    # Fresh list per call: a mutable ``[]`` default would be shared between
    # calls, so a second default call would wrongly take the "later task" path.
    if feature_list is None:
        feature_list = []
    print('Threshold: ', threshold)
    # Decide once up front, because the loop below appends to feature_list.
    first_task = not feature_list

    for i, activation in enumerate(mat_list):
        U, S, _ = np.linalg.svd(activation, full_matrices=False)
        sval_total = (S**2).sum()
        sval_ratio = (S**2) / sval_total
        # criteria (Eq-5): leading directions whose cumulative energy ratio
        # stays below the threshold.
        r = np.sum(np.cumsum(sval_ratio) < threshold[i])  # +1
        if every_task_base is not None:
            every_task_base[task_id][i] = U[:, 0:r]

        if first_task:
            feature_list.append(U[:, 0:r])
            if proj is not None:
                proj[task_id][i] = U[:, 0:r]
            continue

        # criteria (Eq-9): residual of the activation outside the stored basis.
        act_hat = activation - np.dot(
            np.dot(feature_list[i], feature_list[i].transpose()),
            activation)
        U_hat, S_hat, _ = np.linalg.svd(act_hat, full_matrices=False)
        sval_hat = (S_hat**2).sum()
        sval_ratio = (S_hat**2) / sval_total
        # Energy fraction already explained by the existing basis.
        accumulated_sval = (sval_total - sval_hat) / sval_total

        r = 0
        for ratio in sval_ratio:
            if accumulated_sval < threshold[i]:
                accumulated_sval += ratio
                r += 1
            else:
                break

        if r == 0:
            # Existing basis already meets the threshold: nothing to add.
            print('Skip Updating GPM for layer: {}'.format(i + 1))
            if proj is not None:
                proj[task_id][i] = proj[task_id - 1][i]
            continue

        Ui = np.hstack((feature_list[i], U_hat[:, 0:r]))
        if Ui.shape[1] > Ui.shape[0]:
            print('-' * 40)
            print('Base Matrix has OOM')
            print('-' * 40)
            # A basis cannot have more columns than its dimension.
            Ui = Ui[:, 0:Ui.shape[0]]
        feature_list[i] = Ui
        if proj is not None:
            proj[task_id][i] = U_hat[:, 0:r]

    print('-' * 40)
    print('Gradient Constraints Summary')
    print('-' * 40)
    for i, basis in enumerate(feature_list):
        print('Layer {} : {}/{}'.format(i + 1, basis.shape[1],
                                        basis.shape[0]))
    print('-' * 40)
    return feature_list



def train(args, epoch, task_id, model, device, x, y, optimizer, criterion):
    """Run one epoch of plain (unprojected) training on ``x`` / ``y``.

    Samples are visited in an order shuffled with ``np.random.shuffle``, in
    mini-batches of ``args.batch_size_train``; the final batch may be smaller.
    Inputs whose second dimension is 1 are flattened to ``(batch, -1)``.
    For multi-head models (``args.multi_head``) the output at index
    ``task_id`` is used. ``epoch`` is accepted but not used.
    """
    model.train()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    step = args.batch_size_train
    # Single-channel inputs are flattened before the forward pass.
    flatten = x.shape[1] == 1

    for start in range(0, len(order), step):
        # Slicing clips at the end, so the short last batch needs no special case.
        idx = order[start:start + step]
        inputs = x[idx].view(len(idx), -1) if flatten else x[idx]
        inputs, labels = inputs.to(device), y[idx].to(device)

        optimizer.zero_grad()
        logits = model(inputs, task_id, None, -1)
        if args.multi_head:
            # Multi-head output: select the head belonging to this task.
            logits = logits[task_id]
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()



def test(args, model, device, x, y, criterion, task_id):
    model.eval()
    total_loss = 0
    total_num = 0
    correct = 0
    r = np.arange(x.size(0))
    np.random.shuffle(r)
    r = torch.LongTensor(r).to(device)
    with torch.no_grad():
        # Loop batches
        for i in range(0, len(r), args.batch_size_test):
            if i + args.batch_size_test <= len(r):
                b = r[i:i + args.batch_size_test]
            else:
                b = r[i:]
            #如果是单通道
            if x.shape[1] == 1:
                data = x[b].view(len(b), -1)
            else:
                data = x[b]
            data, target = data.to(device), y[b].to(device)
            output = model(data, task_id, None, -1)
            #如果output是多头的话，需要加上task_id
            if args.multi_head:
                output = output[task_id]
            loss = criterion(output, target)
            pred = output.argmax(dim=1, keepdim=True)

            correct += pred.eq(target.view_as(pred)).sum().item()
            total_loss += loss.data.cpu().numpy().item() * len(b)
            total_num += len(b)

    acc = 100. * correct / total_num
    final_loss = total_loss / total_num
    return final_loss, acc




def contrast_cls(every_task_base, sim_tasks,
                 sim_scores, model, task_id, device, criterion):
    """Contrastive regularizer on weights projected onto past-task subspaces.

    Parameters are visited in ``model.named_parameters()`` order. The
    ``cnt``-th parameter is paired with ``every_task_base[*][cnt]``,
    ``sim_tasks[cnt]`` and ``sim_scores[cnt]``.

    For past tasks ``tt = 0, 1, ...`` the weight is projected twice per
    iteration, until at least four projections exist:

    * with ``B @ B.T``, where ``B = every_task_base[tt][cnt]``;
    * with the base of one task drawn by ``random.sample`` from the past tasks
      *not* listed in ``sim_tasks[cnt]``.

    Unless ``sum(sim_scores[cnt]) == 2``, the layer adds a cross-entropy over
    the pairwise cosine similarities of the flattened projections
    (self-similarity masked, temperature 0.05). The two projections made in
    the same iteration are each other's target.

    NOTE(review): the first projection of each pair iterates ``range(task_id)``
    rather than ``sim_tasks[cnt]`` — confirm this is intended.

    Args:
        every_task_base: ``every_task_base[task][layer]`` numpy bases,
            presumably with orthonormal columns (``B @ B.T`` is used as the
            projection matrix).
        sim_tasks: per-layer collections of task ids treated as similar.
        sim_scores: per-layer scores; a layer whose scores sum to 2 is skipped.
        model: module whose parameters are regularized.
        task_id: current task; past tasks are ``range(task_id)``. Must be
            >= 1, otherwise there is nothing to stack.
        device: device on which the projection matrices are built.
        criterion: unused; kept for interface compatibility.

    Returns:
        Sum of the per-layer losses (a scalar tensor), or ``0`` when every
        layer is skipped.

    Raises:
        ValueError: from ``random.sample`` when every past task is listed as
            similar for some layer.
    """
    l2 = 0
    for cnt, (m, params) in enumerate(model.named_parameters()):
        # Use the parameter's own shape. params.grad has the same shape but is
        # None after optimizer.zero_grad() (set_to_none=True by default).
        sz = params.size(0)
        sim = []
        for tt in range(task_id):
            # Projection onto the subspace of past task ``tt``.
            tmp = torch.FloatTensor(every_task_base[tt][cnt]).to(device)
            norm_project = torch.mm(tmp, tmp.transpose(1, 0))
            if "conv" in m:
                proj_weight = torch.mm(params.view(sz, -1),
                                       norm_project).view(params.size())
            else:
                proj_weight = torch.mm(params, norm_project)
            sim.append(proj_weight)

            # Projection onto a randomly chosen task not marked as similar.
            dis = list(set(range(task_id)) - set(sim_tasks[cnt]))
            neg_tt = random.sample(dis, 1)[0]
            tmp = torch.FloatTensor(every_task_base[neg_tt][cnt]).to(device)
            norm_project = torch.mm(tmp, tmp.transpose(1, 0))
            proj_weight = torch.mm(params.view(sz, -1),
                                   norm_project).view(params.size())
            sim.append(proj_weight)
            if len(sim) >= 4:
                break

        # One flattened row per projection. A hard-coded 4 rows would silently
        # re-split the data when only one pair was collected.
        n_proj = len(sim)
        sim = torch.stack(sim).view(n_proj, -1)
        if sum(sim_scores[cnt]) != 2:
            idxs = torch.arange(0, sim.shape[0], device=device)
            # Target of each row is its pair partner: 0<->1, 2<->3, ...
            y_true = idxs + 1 - idxs % 2 * 2
            similarities = F.cosine_similarity(sim.unsqueeze(1), sim.unsqueeze(0), dim=2)
            # Mask the diagonal so a row can never select itself.
            similarities = similarities - torch.eye(sim.shape[0], device=device) * 1e12
            similarities = similarities / 0.05

            loss = F.cross_entropy(similarities, y_true)
            l2 += torch.mean(loss)

    return l2

def train_projected(args, p, model, device, x, y, optimizer, criterion, feature_mat, task_id, epoch, sim_tasks, sim_scores, every_task_base):
    """Run one training epoch with GPM gradient projection.

    Batching and data handling:

    * Mini-batches of ``args.batch_size_train`` are drawn in an order shuffled
      with ``np.random.shuffle``.
    * Each batch is fed as ``model(data, task_id, p, epoch=epoch + offset)``,
      where ``offset`` is the batch's starting position in the shuffled order.
    * Inputs whose second dimension is 1 are flattened.
    * For multi-head models the output at index ``task_id`` is used.
    * When ``sim_tasks`` is non-empty, ``contrast_cls(...)`` is added to the
      task loss.

    After ``backward()`` and before ``optimizer.step()``, only the first three
    entries of ``model.named_parameters()`` are modified:

    * Non-1-D parameters have ``grad.view(rows, -1) @ feature_mat[kk]``
      subtracted from their gradient, where ``kk`` counts these parameters.
      This assumes ``feature_mat[kk]`` is a projection matrix, which is not
      checked here.
    * 1-D parameters get a zero gradient for every task except task 0.
    """
    model.train()
    order = np.arange(x.size(0))
    np.random.shuffle(order)
    order = torch.LongTensor(order).to(device)
    step = args.batch_size_train
    flatten = x.shape[1] == 1
    use_contrast = len(sim_tasks) != 0

    for start in range(0, len(order), step):
        # Slicing clips at the end, so the short last batch needs no special case.
        idx = order[start:start + step]
        inputs = x[idx].view(len(idx), -1) if flatten else x[idx]
        inputs, labels = inputs.to(device), y[idx].to(device)

        optimizer.zero_grad()
        logits = model(inputs, task_id, p, epoch=epoch + start)
        if args.multi_head:
            logits = logits[task_id]
        loss = criterion(logits, labels)
        if use_contrast:
            loss += contrast_cls(every_task_base, sim_tasks,
                                 sim_scores, model, task_id, device, criterion)
        loss.backward()

        # Gradient projections: only the first three parameters are touched.
        layer = 0
        for pos, (_, weight) in enumerate(model.named_parameters()):
            if pos >= 3:
                break
            if weight.dim() != 1:
                grad = weight.grad.data
                rows = grad.size(0)
                weight.grad.data = grad - torch.mm(grad.view(rows, -1),
                                                   feature_mat[layer]).view(weight.size())
                layer += 1
            elif task_id != 0:
                weight.grad.data.fill_(0)

        optimizer.step()


